package net.kaaass.tools.gradequery.captcha;

import javax.imageio.ImageIO;
import java.awt.image.BufferedImage;
import java.io.File;
import java.io.IOException;
import java.util.ArrayList;
import java.util.List;

/**
 * Infers the digits shown in a captcha image by template matching.
 *
 * <p>The image is cut into {@value #DIGIT_COUNT} fixed-position slices. Each slice is compared
 * pixel-by-pixel against the reference images for that position, and the reference with the fewest
 * differing pixels determines the digit.
 */
public class Infer {
    /** Number of digit positions in a captcha image. */
    private static final int DIGIT_COUNT = 4;
    /** Number of reference images per position (candidate digits {@code 0..8}). */
    private static final int CANDIDATE_COUNT = 9;
    /** Reference images are read from {@code captcha/<position>/<digit>.png}. */
    private static final String MODEL_DIR = "captcha/";

    /** Left x coordinate of each digit slice. */
    private static final int[] SLICE_X = {0, 16, 36, 58};
    /** Width of each digit slice. Column 35 lies between slices 1 and 2 and belongs to neither. */
    private static final int[] SLICE_WIDTH = {16, 19, 22, 22};
    /** Height shared by all digit slices. */
    private static final int SLICE_HEIGHT = 40;

    /**
     * Reference images indexed as {@code MODEL.get(position).get(digit)}. Either fully loaded or
     * empty; it is never partially populated.
     */
    private static final List<List<BufferedImage>> MODEL = loadModel();

    /**
     * Loads every reference image from {@link #MODEL_DIR}.
     *
     * @return the complete model, or an empty list if any image could not be read
     */
    private static List<List<BufferedImage>> loadModel() {
        try {
            List<List<BufferedImage>> loaded = new ArrayList<>(DIGIT_COUNT);
            for (int pos = 0; pos < DIGIT_COUNT; pos++) {
                List<BufferedImage> templates = new ArrayList<>(CANDIDATE_COUNT);
                for (int digit = 0; digit < CANDIDATE_COUNT; digit++) {
                    File file = new File(MODEL_DIR + pos + "/" + digit + ".png");
                    BufferedImage template = ImageIO.read(file);
                    // ImageIO.read returns null (instead of throwing) when no reader supports the file.
                    if (template == null) {
                        throw new IOException("Unsupported image format: " + file);
                    }
                    templates.add(template);
                }
                loaded.add(templates);
            }
            return loaded;
        } catch (Exception e) {
            // Best-effort: report and leave the model empty; read() then fails with a clear message.
            System.err.println("Error occurred in reading captcha model: " + e + ", " + e.getLocalizedMessage());
            return new ArrayList<>();
        }
    }

    /**
     * Cuts a captcha image into one sub-image per digit position.
     *
     * <p>The slice bounds are fixed and must match the ones used to produce the reference images.
     *
     * @param image the captcha image; must be at least 80 pixels wide and 40 pixels high
     * @return {@value #DIGIT_COUNT} sub-images that share pixel data with {@code image}
     * @throws java.awt.image.RasterFormatException if {@code image} is too small for the slices
     */
    public static List<BufferedImage> splitImage(BufferedImage image) throws Exception {
        List<BufferedImage> digitImageList = new ArrayList<>(DIGIT_COUNT);
        for (int pos = 0; pos < DIGIT_COUNT; pos++) {
            digitImageList.add(image.getSubimage(SLICE_X[pos], 0, SLICE_WIDTH[pos], SLICE_HEIGHT));
        }
        return digitImageList;
    }

    /**
     * Counts the pixels at which two images differ.
     *
     * <p>Pixels are compared by exact RGB value over the region where both images overlap. When the
     * sizes differ, every pixel outside that overlap counts as differing, so mismatched sizes never
     * cause an out-of-bounds access.
     *
     * @param a first image
     * @param b second image
     * @return number of differing pixels ({@code 0} for identical images)
     */
    private static int diff(BufferedImage a, BufferedImage b) {
        int overlapWidth = Math.min(a.getWidth(), b.getWidth());
        int overlapHeight = Math.min(a.getHeight(), b.getHeight());
        int overlapArea = overlapWidth * overlapHeight;
        int count = (a.getWidth() * a.getHeight() - overlapArea)
                + (b.getWidth() * b.getHeight() - overlapArea);
        for (int x = 0; x < overlapWidth; x++) {
            for (int y = 0; y < overlapHeight; y++) {
                if (a.getRGB(x, y) != b.getRGB(x, y)) {
                    count++;
                }
            }
        }
        return count;
    }

    /**
     * Reads a captcha image file and infers its digits.
     *
     * @param file the captcha image file
     * @return the inferred digits, one character per position
     * @throws IOException if the file cannot be read or its format is not supported by ImageIO
     * @throws Exception if inference fails; see {@link #read(BufferedImage)}
     */
    public static String read(File file) throws Exception {
        BufferedImage image = ImageIO.read(file);
        // ImageIO.read returns null for unsupported formats; fail here rather than with an NPE later.
        if (image == null) {
            throw new IOException("Unsupported image format: " + file);
        }
        return read(image);
    }

    /**
     * Infers the digits shown in a captcha image.
     *
     * <p>{@code Filtering.binaryzation} is applied to {@code image} before splitting. Its return value
     * is not used, so it presumably modifies {@code image} in place (NOTE(review): confirm).
     *
     * @param image the captcha image
     * @return the inferred digits, one character per position
     * @throws IllegalStateException if the reference model failed to load
     * @throws Exception if the image cannot be split; see {@link #splitImage(BufferedImage)}
     */
    public static String read(BufferedImage image) throws Exception {
        if (MODEL.size() < DIGIT_COUNT) {
            throw new IllegalStateException("Captcha model is not loaded");
        }
        Filtering.binaryzation(image);
        List<BufferedImage> slices = splitImage(image);
        StringBuilder result = new StringBuilder(DIGIT_COUNT);
        for (int pos = 0; pos < DIGIT_COUNT; pos++) {
            result.append(bestMatch(slices.get(pos), MODEL.get(pos)));
        }
        return result.toString();
    }

    /**
     * Finds the template most similar to a slice.
     *
     * @param slice     the digit slice to classify
     * @param templates candidate reference images, indexed by digit
     * @return index of the template with the fewest differing pixels; ties go to the lowest index
     */
    private static int bestMatch(BufferedImage slice, List<BufferedImage> templates) {
        int best = 0;
        // Start above any possible count so the first template is always accepted.
        int bestDiff = Integer.MAX_VALUE;
        for (int digit = 0; digit < templates.size(); digit++) {
            int d = diff(slice, templates.get(digit));
            if (d < bestDiff) {
                bestDiff = d;
                best = digit;
            }
        }
        return best;
    }
}
